{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "245ab07a-fb2f-4cf4-ab9a-5c05a9b44daa",
   "metadata": {},
   "source": [
    "# LangChain retrieval knowledge base Q&A based on Qwen-7B-Chat"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e8df2cb7-a69c-4231-9596-4c871d893633",
   "metadata": {},
   "source": [
    "This notebook introduces a question-answering application based on a local knowledge base using Qwen-7B-Chat with langchain. The goal is to establish a knowledge base Q&A solution that is friendly to many scenarios and open-source models, and that can run offline. The implementation process of this project includes loading files -> reading text -> segmenting text -> vectorizing text -> vectorizing questions -> matching the top k most similar text vectors with the question vectors -> incorporating the matched text as context along with the question into the prompt -> submitting to the LLM (Large Language Model) to generate an answer."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92e9c81a-45c7-4c12-91af-3c5dd52f63bb",
   "metadata": {},
   "source": [
    "## Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84cfcf88-3bef-4412-a658-4eaefeb6502a",
   "metadata": {},
   "source": [
    "Download Qwen-7B-Chat\n",
    "\n",
    "Firstly, we need to download the model. You can use the snapshot_download that comes with modelscope to download the model to a specified directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c1f9ded-8035-42c7-82c7-444ce06572bc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install modelscope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c26225c-c958-429e-b81d-2de9820670c2",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from modelscope.hub.snapshot_download import snapshot_download\n",
    "snapshot_download(\"Qwen/Qwen-7B-Chat\",cache_dir='/tmp/models') "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8f51796-49fa-467d-a825-ae9a281eb3fd",
   "metadata": {},
   "source": [
    "Download the dependencies for langchain and Qwen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87fe1023-644f-4610-afaf-0b7cddc30d60",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install langchain==0.0.187 dashscope==1.0.4 sentencepiece==0.1.99 cpm_kernels==1.0.11 nltk==3.8.1 sentence_transformers==2.2.2 unstructured==0.6.5 faiss-cpu==1.7.4 icetk==0.0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "853cdfa4-a2ce-4baa-919a-b9e2aecd2706",
   "metadata": {},
   "source": [
    "Download the retrieval document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ba800dc-311d-4a83-8115-f05b09b39ffd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/qwen_recipes/LLM_Survey_Chinese.pdf.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07e923b3-b7ae-4983-abeb-2ce115566f15",
   "metadata": {},
   "source": [
    "Download the text2vec model, for Chinese in our case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a07cd8d-3cec-40f6-8d2b-eb111aaf1164",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/qwen_recipes/GanymedeNil_text2vec-large-chinese.tar.gz\n",
    "!tar -zxvf GanymedeNil_text2vec-large-chinese.tar.gz -C /tmp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc483af0-170e-4e61-8d25-a336d1592e34",
   "metadata": {},
   "source": [
    "## Try out the model \n",
    "\n",
    "Load the Qwen-7B-Chat model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c112cf82-0447-46c4-9c32-18f243c0a686",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from abc import ABC\n",
    "from langchain.llms.base import LLM\n",
    "from typing import Any, List, Mapping, Optional\n",
    "from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_path=\"/tmp/models/Qwen/Qwen-7B-Chat\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()\n",
    "model.eval()\n",
    "\n",
    "class Qwen(LLM, ABC):\n",
    "    max_token: int = 10000\n",
    "    temperature: float = 0.01\n",
    "    top_p = 0.9\n",
    "    history_len: int = 3\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @property\n",
    "    def _llm_type(self) -> str:\n",
    "        return \"Qwen\"\n",
    "\n",
    "    @property\n",
    "    def _history_len(self) -> int:\n",
    "        return self.history_len\n",
    "\n",
    "    def set_history_len(self, history_len: int = 10) -> None:\n",
    "        self.history_len = history_len\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        prompt: str,\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "    ) -> str:\n",
    "        response, _ = model.chat(tokenizer, prompt, history=[])\n",
    "        return response\n",
    "    \n",
    "    @property\n",
    "    def _identifying_params(self) -> Mapping[str, Any]:\n",
    "        \"\"\"Get the identifying parameters.\"\"\"\n",
    "        return {\"max_token\": self.max_token,\n",
    "                \"temperature\": self.temperature,\n",
    "                \"top_p\": self.top_p,\n",
    "                \"history_len\": self.history_len}\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "382ed433-870f-424e-b074-210ea6f84b70",
   "metadata": {},
   "source": [
    "Specify the txt file that needs retrieval for knowledge-based Q&A."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14be706b-4a7d-4906-9369-1f03c6c99854",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "from langchain.vectorstores import FAISS\n",
    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
    "from typing import List, Tuple\n",
    "import numpy as np\n",
    "from langchain.document_loaders import TextLoader\n",
    "from chinese_text_splitter import ChineseTextSplitter\n",
    "from langchain.docstore.document import Document\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "from langchain.chains import RetrievalQA\n",
    "\n",
    "\n",
    "def load_file(filepath, sentence_size=100):\n",
    "    loader = TextLoader(filepath, autodetect_encoding=True)\n",
    "    textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)\n",
    "    docs = loader.load_and_split(textsplitter)\n",
    "    write_check_file(filepath, docs)\n",
    "    return docs\n",
    "\n",
    "\n",
    "def write_check_file(filepath, docs):\n",
    "    folder_path = os.path.join(os.path.dirname(filepath), \"tmp_files\")\n",
    "    if not os.path.exists(folder_path):\n",
    "        os.makedirs(folder_path)\n",
    "    fp = os.path.join(folder_path, 'load_file.txt')\n",
    "    with open(fp, 'a+', encoding='utf-8') as fout:\n",
    "        fout.write(\"filepath=%s,len=%s\" % (filepath, len(docs)))\n",
    "        fout.write('\\n')\n",
    "        for i in docs:\n",
    "            fout.write(str(i))\n",
    "            fout.write('\\n')\n",
    "        fout.close()\n",
    "\n",
    "        \n",
    "def seperate_list(ls: List[int]) -> List[List[int]]:\n",
    "    lists = []\n",
    "    ls1 = [ls[0]]\n",
    "    for i in range(1, len(ls)):\n",
    "        if ls[i - 1] + 1 == ls[i]:\n",
    "            ls1.append(ls[i])\n",
    "        else:\n",
    "            lists.append(ls1)\n",
    "            ls1 = [ls[i]]\n",
    "    lists.append(ls1)\n",
    "    return lists\n",
    "\n",
    "\n",
    "class FAISSWrapper(FAISS):\n",
    "    chunk_size = 250\n",
    "    chunk_conent = True\n",
    "    score_threshold = 0\n",
    "    \n",
    "    def similarity_search_with_score_by_vector(\n",
    "            self, embedding: List[float], k: int = 4\n",
    "    ) -> List[Tuple[Document, float]]:\n",
    "        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)\n",
    "        docs = []\n",
    "        id_set = set()\n",
    "        store_len = len(self.index_to_docstore_id)\n",
    "        for j, i in enumerate(indices[0]):\n",
    "            if i == -1 or 0 < self.score_threshold < scores[0][j]:\n",
    "                # This happens when not enough docs are returned.\n",
    "                continue\n",
    "            _id = self.index_to_docstore_id[i]\n",
    "            doc = self.docstore.search(_id)\n",
    "            if not self.chunk_conent:\n",
    "                if not isinstance(doc, Document):\n",
    "                    raise ValueError(f\"Could not find document for id {_id}, got {doc}\")\n",
    "                doc.metadata[\"score\"] = int(scores[0][j])\n",
    "                docs.append(doc)\n",
    "                continue\n",
    "            id_set.add(i)\n",
    "            docs_len = len(doc.page_content)\n",
    "            for k in range(1, max(i, store_len - i)):\n",
    "                break_flag = False\n",
    "                for l in [i + k, i - k]:\n",
    "                    if 0 <= l < len(self.index_to_docstore_id):\n",
    "                        _id0 = self.index_to_docstore_id[l]\n",
    "                        doc0 = self.docstore.search(_id0)\n",
    "                        if docs_len + len(doc0.page_content) > self.chunk_size:\n",
    "                            break_flag = True\n",
    "                            break\n",
    "                        elif doc0.metadata[\"source\"] == doc.metadata[\"source\"]:\n",
    "                            docs_len += len(doc0.page_content)\n",
    "                            id_set.add(l)\n",
    "                if break_flag:\n",
    "                    break\n",
    "        if not self.chunk_conent:\n",
    "            return docs\n",
    "        if len(id_set) == 0 and self.score_threshold > 0:\n",
    "            return []\n",
    "        id_list = sorted(list(id_set))\n",
    "        id_lists = seperate_list(id_list)\n",
    "        for id_seq in id_lists:\n",
    "            for id in id_seq:\n",
    "                if id == id_seq[0]:\n",
    "                    _id = self.index_to_docstore_id[id]\n",
    "                    doc = self.docstore.search(_id)\n",
    "                else:\n",
    "                    _id0 = self.index_to_docstore_id[id]\n",
    "                    doc0 = self.docstore.search(_id0)\n",
    "                    doc.page_content += \" \" + doc0.page_content\n",
    "            if not isinstance(doc, Document):\n",
    "                raise ValueError(f\"Could not find document for id {_id}, got {doc}\")\n",
    "            doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])\n",
    "            doc.metadata[\"score\"] = int(doc_score)\n",
    "            docs.append((doc, doc_score))\n",
    "        return docs\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # load docs\n",
    "    filepath = 'LLM_Survey_Chinese.pdf.txt'\n",
    "    # LLM name\n",
    "    LLM_TYPE = 'qwen'\n",
    "    # Embedding model name\n",
    "    EMBEDDING_MODEL = 'text2vec'\n",
    "    # 基于上下文的prompt模版，请务必保留\"{question}\"和\"{context_str}\"\n",
    "    PROMPT_TEMPLATE = \"\"\"已知信息：\n",
    "    {context_str} \n",
    "    根据上述已知信息，简洁和专业的来回答用户的问题。如果无法从中得到答案，请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”，不允许在答案中添加编造成分，答案请使用中文。 问题是：{question}\"\"\"\n",
    "    # Embedding running device\n",
    "    EMBEDDING_DEVICE = \"cuda\"\n",
    "    # return top-k text chunk from vector store\n",
    "    VECTOR_SEARCH_TOP_K = 3\n",
    "    # 文本分句长度\n",
    "    SENTENCE_SIZE = 50\n",
    "    CHAIN_TYPE = 'stuff'\n",
    "    llm_model_dict = {\n",
    "        \"qwen\": QWen,\n",
    "    }\n",
    "    embedding_model_dict = {\n",
    "        \"text2vec\": \"/tmp/GanymedeNil_text2vec-large-chinese\",\n",
    "    }\n",
    "    print(\"loading model start\")\n",
    "    llm = llm_model_dict[LLM_TYPE]()\n",
    "    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],model_kwargs={'device': EMBEDDING_DEVICE})\n",
    "    print(\"loading model done\")\n",
    "\n",
    "    print(\"loading documents start\")\n",
    "    docs = load_file(filepath, sentence_size=SENTENCE_SIZE)\n",
    "    print(\"loading documents done\")\n",
    "\n",
    "    print(\"embedding start\")\n",
    "    docsearch = FAISSWrapper.from_documents(docs, embeddings)\n",
    "    print(\"embedding done\")\n",
    "\n",
    "    print(\"loading qa start\")\n",
    "    prompt = PromptTemplate(\n",
    "        template=PROMPT_TEMPLATE, input_variables=[\"context_str\", \"question\"]\n",
    "    )\n",
    "\n",
    "    chain_type_kwargs = {\"prompt\": prompt, \"document_variable_name\": \"context_str\"}\n",
    "    qa = RetrievalQA.from_chain_type(\n",
    "        llm=llm,\n",
    "        chain_type=CHAIN_TYPE, \n",
    "        retriever=docsearch.as_retriever(search_kwargs={\"k\": VECTOR_SEARCH_TOP_K}), \n",
    "        chain_type_kwargs=chain_type_kwargs)\n",
    "    print(\"loading qa done\")\n",
    "\n",
    "    query = \"大模型指令微调有好的策略？\"  \n",
    "    print(qa.run(query))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
